{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# 线性回归的从零开始实现\n",
    ":label:`sec_linear_scratch`\n",
    "\n",
    "在了解线性回归的关键思想之后，我们可以开始通过代码来动手实现线性回归了。\n",
    "在这一节中，(**我们将从零开始实现整个方法，\n",
    "包括数据流水线、模型、损失函数和小批量随机梯度下降优化器**)。\n",
    "虽然现代的深度学习框架几乎可以自动化地进行所有这些工作，但从零开始实现可以确保你真正知道自己在做什么。\n",
    "同时，了解更细致的工作原理将方便我们自定义模型、自定义层或自定义损失函数。\n",
    "在这一节中，我们将只使用张量和自动求导。\n",
    "在之后的章节中，我们会充分利用深度学习框架的优势，介绍更简洁的实现方式。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: d2l in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (0.17.1)\n",
      "Requirement already satisfied: numpy==1.18.5 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from d2l) (1.18.5)\n",
      "Requirement already satisfied: jupyter==1.0.0 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from d2l) (1.0.0)\n",
      "Requirement already satisfied: matplotlib==3.3.3 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from d2l) (3.3.3)\n",
      "Requirement already satisfied: pandas==1.2.2 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from d2l) (1.2.2)\n",
      "Requirement already satisfied: requests==2.25.1 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from d2l) (2.25.1)\n",
      "Requirement already satisfied: nbconvert in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from jupyter==1.0.0->d2l) (5.4.0)\n",
      "Requirement already satisfied: ipywidgets in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from jupyter==1.0.0->d2l) (7.4.1)\n",
      "Requirement already satisfied: jupyter-console in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from jupyter==1.0.0->d2l) (5.2.0)\n",
      "Requirement already satisfied: qtconsole in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from jupyter==1.0.0->d2l) (4.4.1)\n",
      "Requirement already satisfied: notebook in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from jupyter==1.0.0->d2l) (5.6.0)\n",
      "Requirement already satisfied: ipykernel in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from jupyter==1.0.0->d2l) (4.9.0)\n",
      "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from matplotlib==3.3.3->d2l) (3.0.6)\n",
      "Requirement already satisfied: pillow>=6.2.0 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from matplotlib==3.3.3->d2l) (8.4.0)\n",
      "Requirement already satisfied: python-dateutil>=2.1 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from matplotlib==3.3.3->d2l) (2.7.3)\n",
      "Requirement already satisfied: kiwisolver>=1.0.1 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from matplotlib==3.3.3->d2l) (1.0.1)\n",
      "Requirement already satisfied: cycler>=0.10 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from matplotlib==3.3.3->d2l) (0.10.0)\n",
      "Requirement already satisfied: pytz>=2017.3 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from pandas==1.2.2->d2l) (2018.5)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from requests==2.25.1->d2l) (2021.10.8)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from requests==2.25.1->d2l) (1.23)\n",
      "Requirement already satisfied: idna<3,>=2.5 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from requests==2.25.1->d2l) (2.7)\n",
      "Requirement already satisfied: chardet<5,>=3.0.2 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from requests==2.25.1->d2l) (3.0.4)\n",
      "Requirement already satisfied: six in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from cycler>=0.10->matplotlib==3.3.3->d2l) (1.11.0)\n",
      "Requirement already satisfied: setuptools in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from kiwisolver>=1.0.1->matplotlib==3.3.3->d2l) (40.2.0)\n",
      "Requirement already satisfied: ipython>=4.0.0 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from ipykernel->jupyter==1.0.0->d2l) (6.5.0)\n",
      "Requirement already satisfied: traitlets>=4.1.0 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from ipykernel->jupyter==1.0.0->d2l) (4.3.2)\n",
      "Requirement already satisfied: jupyter_client in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from ipykernel->jupyter==1.0.0->d2l) (5.2.3)\n",
      "Requirement already satisfied: tornado>=4.0 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from ipykernel->jupyter==1.0.0->d2l) (5.1)\n",
      "Requirement already satisfied: nbformat>=4.2.0 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from ipywidgets->jupyter==1.0.0->d2l) (4.4.0)\n",
      "Requirement already satisfied: widgetsnbextension~=3.4.0 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from ipywidgets->jupyter==1.0.0->d2l) (3.4.1)\n",
      "Requirement already satisfied: pygments in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from jupyter-console->jupyter==1.0.0->d2l) (2.2.0)\n",
      "Requirement already satisfied: prompt-toolkit<2.0.0,>=1.0.0 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from jupyter-console->jupyter==1.0.0->d2l) (1.0.15)\n",
      "Requirement already satisfied: mistune>=0.8.1 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from nbconvert->jupyter==1.0.0->d2l) (0.8.3)\n",
      "Requirement already satisfied: jinja2 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from nbconvert->jupyter==1.0.0->d2l) (2.10)\n",
      "Requirement already satisfied: jupyter_core in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from nbconvert->jupyter==1.0.0->d2l) (4.4.0)\n",
      "Requirement already satisfied: entrypoints>=0.2.2 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from nbconvert->jupyter==1.0.0->d2l) (0.2.3)\n",
      "Requirement already satisfied: bleach in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from nbconvert->jupyter==1.0.0->d2l) (2.1.4)\n",
      "Requirement already satisfied: pandocfilters>=1.4.1 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from nbconvert->jupyter==1.0.0->d2l) (1.4.2)\n",
      "Requirement already satisfied: testpath in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from nbconvert->jupyter==1.0.0->d2l) (0.3.1)\n",
      "Requirement already satisfied: defusedxml in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from nbconvert->jupyter==1.0.0->d2l) (0.5.0)\n",
      "Requirement already satisfied: Send2Trash in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from notebook->jupyter==1.0.0->d2l) (1.5.0)\n",
      "Requirement already satisfied: prometheus-client in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from notebook->jupyter==1.0.0->d2l) (0.3.1)\n",
      "Requirement already satisfied: terminado>=0.8.1 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from notebook->jupyter==1.0.0->d2l) (0.8.1)\n",
      "Requirement already satisfied: ipython-genutils in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from notebook->jupyter==1.0.0->d2l) (0.2.0)\n",
      "Requirement already satisfied: pyzmq>=17 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from notebook->jupyter==1.0.0->d2l) (17.1.2)\n",
      "Requirement already satisfied: pickleshare in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from ipython>=4.0.0->ipykernel->jupyter==1.0.0->d2l) (0.7.4)\n",
      "Requirement already satisfied: backcall in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from ipython>=4.0.0->ipykernel->jupyter==1.0.0->d2l) (0.1.0)\n",
      "Requirement already satisfied: appnope in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from ipython>=4.0.0->ipykernel->jupyter==1.0.0->d2l) (0.1.0)\n",
      "Requirement already satisfied: pexpect in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from ipython>=4.0.0->ipykernel->jupyter==1.0.0->d2l) (4.6.0)\n",
      "Requirement already satisfied: decorator in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from ipython>=4.0.0->ipykernel->jupyter==1.0.0->d2l) (4.3.0)\n",
      "Requirement already satisfied: simplegeneric>0.8 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from ipython>=4.0.0->ipykernel->jupyter==1.0.0->d2l) (0.8.1)\n",
      "Requirement already satisfied: jedi>=0.10 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from ipython>=4.0.0->ipykernel->jupyter==1.0.0->d2l) (0.12.1)\n",
      "Requirement already satisfied: jsonschema!=2.5.0,>=2.4 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from nbformat>=4.2.0->ipywidgets->jupyter==1.0.0->d2l) (2.6.0)\n",
      "Requirement already satisfied: wcwidth in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from prompt-toolkit<2.0.0,>=1.0.0->jupyter-console->jupyter==1.0.0->d2l) (0.1.7)\n",
      "Requirement already satisfied: html5lib!=1.0b1,!=1.0b2,!=1.0b3,!=1.0b4,!=1.0b5,!=1.0b6,!=1.0b7,!=1.0b8,>=0.99999999pre in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from bleach->nbconvert->jupyter==1.0.0->d2l) (1.0.1)\n",
      "Requirement already satisfied: MarkupSafe>=0.23 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from jinja2->nbconvert->jupyter==1.0.0->d2l) (1.0)\n",
      "Requirement already satisfied: webencodings in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from html5lib!=1.0b1,!=1.0b2,!=1.0b3,!=1.0b4,!=1.0b5,!=1.0b6,!=1.0b7,!=1.0b8,>=0.99999999pre->bleach->nbconvert->jupyter==1.0.0->d2l) (0.5.1)\n",
      "Requirement already satisfied: parso>=0.3.0 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from jedi>=0.10->ipython>=4.0.0->ipykernel->jupyter==1.0.0->d2l) (0.3.1)\n",
      "Requirement already satisfied: ptyprocess>=0.5 in /Users/wangjiaqi/anaconda3/lib/python3.7/site-packages (from pexpect->ipython>=4.0.0->ipykernel->jupyter==1.0.0->d2l) (0.6.0)\n"
     ]
    }
   ],
   "source": [
    "!pip install d2l\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "origin_pos": 2,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'torchvision'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-2-af8953f91218>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mrandom\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtorchvision\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      6\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0md2l\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0md2l\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torchvision'"
     ]
    }
   ],
   "source": [
    "#导入工具包\n",
    "%matplotlib inline \n",
    "import random\n",
    "import torch\n",
    "import torchvision\n",
    "from d2l import torch as d2l"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 4
   },
   "source": [
    "## 生成数据集\n",
    "\n",
    "为了简单起见，我们将[**根据带有噪声的线性模型构造一个人造数据集。**]\n",
    "我们的任务是使用这个有限样本的数据集来恢复这个模型的参数。\n",
    "我们将使用低维数据，这样可以很容易地将其可视化。\n",
    "在下面的代码中，我们生成一个包含1000个样本的数据集，\n",
    "每个样本包含从标准正态分布中采样的2个特征。\n",
    "我们的合成数据集是一个矩阵$\\mathbf{X}\\in \\mathbb{R}^{1000 \\times 2}$。\n",
    "\n",
    "(**我们使用线性模型参数$\\mathbf{w} = [2, -3.4]^\\top$、$b = 4.2$\n",
    "和噪声项$\\epsilon$生成数据集及其标签：\n",
    "\n",
    "$$\\mathbf{y}= \\mathbf{X} \\mathbf{w} + b + \\mathbf\\epsilon.$$\n",
    "**)\n",
    "\n",
    "你可以将$\\epsilon$视为模型预测和标签时的潜在观测误差。\n",
    "在这里我们认为标准假设成立，即$\\epsilon$服从均值为0的正态分布。\n",
    "为了简化问题，我们将标准差设为0.01。\n",
    "下面的代码生成合成数据集。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 5,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def synthetic_data(w, b, num_examples):  #@save\n",
    "    \"\"\"\n",
    "    生成y=Xw+b+噪声\n",
    "    @para w 权重\n",
    "    @para b 偏差 \n",
    "    @para num_examples 样本数量\n",
    "    @return\n",
    "        X 随机生成的特征数据，(num_examples, len(w))\n",
    "        y X对应的标签 （num_examples，1）\n",
    "    \n",
    "    \"\"\"\n",
    "    X = torch.normal(0, 1, (num_examples, len(w)))#生成均值为0，方差为1，数据纬度是（num_examples, len(w)）的随机数据作为训练样本\n",
    "    y = torch.matmul(X, w) + b #生成X对应的预测值y\n",
    "    y += torch.normal(0, 0.01, y.shape)# 加入噪音，加入的是均值为0，方差为0.01，纬度和y.shape一致的噪音进行干扰\n",
    "    return X, y.reshape((-1, 1))#返回X，y，y为列向量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 7,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "true_w = torch.tensor([2, -3.4]) #真实权重\n",
    "true_b = 4.2 #真实偏差\n",
    "features, labels = synthetic_data(true_w, true_b, 1000) #随机生成1000组训练数据及标签"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 8
   },
   "source": [
    "注意，[**`features`中的每一行都包含一个二维数据样本，\n",
    "`labels`中的每一行都包含一维标签值（一个标量）**]。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 9,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "print('features:', features[0],'\\nlabel:', labels[0]) #打印一个样本数据和对应标签"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 10
   },
   "source": [
    "通过生成第二个特征`features[:, 1]`和`labels`的散点图，\n",
    "可以直观观察到两者之间的线性关系。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 11,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "#可视化数据\n",
    "d2l.set_figsize() \n",
    "d2l.plt.scatter(features[:, (1)].detach().numpy(), labels.detach().numpy(), 1); #x轴为features的第一列，y轴为标签值，正相关"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 12
   },
   "source": [
    "## 读取数据集\n",
    "\n",
    "回想一下，训练模型时要对数据集进行遍历，每次抽取一小批量样本，并使用它们来更新我们的模型。\n",
    "由于这个过程是训练机器学习算法的基础，所以有必要定义一个函数，\n",
    "该函数能打乱数据集中的样本并以小批量方式获取数据。\n",
    "\n",
    "在下面的代码中，我们[**定义一个`data_iter`函数，\n",
    "该函数接收批量大小、特征矩阵和标签向量作为输入，生成大小为`batch_size`的小批量**]。\n",
    "每个小批量包含一组特征和标签。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 13,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def data_iter(batch_size, features, labels):\n",
    "    \"\"\"\n",
    "    随机获取一小批样本的数据\n",
    "    @para batch_size 批量的大小\n",
    "    @para features 训练数据\n",
    "    @para labels 训练数据对应的标签\n",
    "    @return\n",
    "        迭代器，每次返回batch_size大小的两组数据，一个是训练样本，一个是对应的标签   \n",
    "        \n",
    "    \"\"\"\n",
    "    num_examples = len(features) #获取样本大小\n",
    "    indices = list(range(num_examples)) #获取样本脚标的list\n",
    "    # 这些样本是随机读取的，没有特定的顺序\n",
    "    random.shuffle(indices) #随机变换indices\n",
    "    for i in range(0, num_examples, batch_size): #开始循环\n",
    "        batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)]) #有可能不能整除，取i + batch_size和num_examples的较小值\n",
    "        yield features[batch_indices], labels[batch_indices] #相当于是一个迭代器，每次返回batch_size个样本"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 15
   },
   "source": [
    "通常，我们利用GPU并行运算的优势，处理合理大小的“小批量”。\n",
    "每个样本都可以并行地进行模型计算，且每个样本损失函数的梯度也可以被并行计算。\n",
    "GPU可以在处理几百个样本时，所花费的时间不比处理一个样本时多太多。\n",
    "\n",
    "我们直观感受一下小批量运算：读取第一个小批量数据样本并打印。\n",
    "每个批量的特征维度显示批量大小和输入特征数。\n",
    "同样的，批量的标签形状与`batch_size`相等。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 16,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "batch_size = 10\n",
    "\n",
    "for X, y in data_iter(batch_size, features, labels):\n",
    "    print(X, '\\n', y) #X为10 x 2的tensor，y为10 x 1的tensor\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 17
   },
   "source": [
    "当我们运行迭代时，我们会连续地获得不同的小批量，直至遍历完整个数据集。\n",
    "上面实现的迭代对于教学来说很好，但它的执行效率很低，可能会在实际问题上陷入麻烦。\n",
    "例如，它要求我们将所有数据加载到内存中，并执行大量的随机内存访问。\n",
    "在深度学习框架中实现的内置迭代器效率要高得多，\n",
    "它可以处理存储在文件中的数据和数据流提供的数据。\n",
    "\n",
    "## 初始化模型参数\n",
    "\n",
    "[**在我们开始用小批量随机梯度下降优化我们的模型参数之前**]，\n",
    "(**我们需要先有一些参数**)。\n",
    "在下面的代码中，我们通过从均值为0、标准差为0.01的正态分布中采样随机数来初始化权重，\n",
    "并将偏置初始化为0。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 19,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "# 由于训练的时候需要更细参数，计算梯度，所以requires_grad=True\n",
    "w = torch.normal(0, 0.01, size=(2,1), requires_grad=True) #w初始化为均值为0，方差为0.001的符合正态分布的数组，纬度为2 x1\n",
    "b = torch.zeros(1, requires_grad=True) #b初始化为0，纬度为1，就是一个实数"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 21
   },
   "source": [
    "在初始化参数之后，我们的任务是更新这些参数，直到这些参数足够拟合我们的数据。\n",
    "每次更新都需要计算损失函数关于模型参数的梯度。\n",
    "有了这个梯度，我们就可以向减小损失的方向更新每个参数。\n",
    "因为手动计算梯度很枯燥而且容易出错，所以没有人会手动计算梯度。\n",
    "我们使用 :numref:`sec_autograd`中引入的自动微分来计算梯度。\n",
    "\n",
    "## 定义模型\n",
    "\n",
    "接下来，我们必须[**定义模型，将模型的输入和参数同模型的输出关联起来。**]\n",
    "回想一下，要计算线性模型的输出，\n",
    "我们只需计算输入特征$\\mathbf{X}$和模型权重$\\mathbf{w}$的矩阵-向量乘法后加上偏置$b$。\n",
    "注意，上面的$\\mathbf{Xw}$是一个向量，而$b$是一个标量。\n",
    "回想一下 :numref:`subsec_broadcasting`中描述的广播机制：\n",
    "当我们用一个向量加一个标量时，标量会被加到向量的每个分量上。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 22,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def linreg(X, w, b):  #@save\n",
    "    \"\"\"\n",
    "    线性回归模型\n",
    "    @para X 训练数据（num_examples,len(w)）\n",
    "    @para w 权重 （2，1）\n",
    "    @para b 偏差 实数\n",
    "    @return\n",
    "        模型的预估值\n",
    "    \n",
    "    \"\"\"\n",
    "    return torch.matmul(X, w) + b"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 23
   },
   "source": [
    "## [**定义损失函数**]\n",
    "\n",
    "因为需要计算损失函数的梯度，所以我们应该先定义损失函数。\n",
    "这里我们使用 :numref:`sec_linear_regression`中描述的平方损失函数。\n",
    "在实现中，我们需要将真实值`y`的形状转换为和预测值`y_hat`的形状相同。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 24,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def squared_loss(y_hat, y):  #@save\n",
    "    \"\"\"\n",
    "    均方损失\n",
    "    @para y_hat 训练数据的真实值（num，1）\n",
    "    @para y 训练数据的预测值\n",
    "    @return\n",
    "        均方误差，没有除以样本数目 (batch_size,1)\n",
    "    \"\"\"\n",
    "    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 25
   },
   "source": [
    "## (**定义优化算法**)\n",
    "\n",
    "正如我们在 :numref:`sec_linear_regression`中讨论的，线性回归有解析解。\n",
    "尽管线性回归有解析解，但本书中的其他模型却没有。\n",
    "这里我们介绍小批量随机梯度下降。\n",
    "\n",
    "在每一步中，使用从数据集中随机抽取的一个小批量，然后根据参数计算损失的梯度。\n",
    "接下来，朝着减少损失的方向更新我们的参数。\n",
    "下面的函数实现小批量随机梯度下降更新。\n",
    "该函数接受模型参数集合、学习速率和批量大小作为输入。每\n",
    "一步更新的大小由学习速率`lr`决定。\n",
    "因为我们计算的损失是一个批量样本的总和，所以我们用批量大小（`batch_size`）\n",
    "来规范化步长，这样步长大小就不会取决于我们对批量大小的选择。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 27,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def sgd(params, lr, batch_size):  #@save\n",
    "    \"\"\"\n",
    "    小批量随机梯度下降\n",
    "    @para params 参数\n",
    "    @para lr 学习率，人为指定\n",
    "    @para batch_size 批量大小\n",
    "    @return \n",
    "    \"\"\"\n",
    "    with torch.no_grad(): #不需要计算梯度\n",
    "        for param in params:\n",
    "            param -= lr * param.grad / batch_size #梯度下降法更新参数\n",
    "            param.grad.zero_() #手动梯度归零"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 29
   },
   "source": [
    "## 训练\n",
    "\n",
    "现在我们已经准备好了模型训练所有需要的要素，可以实现主要的[**训练过程**]部分了。\n",
    "理解这段代码至关重要，因为从事深度学习后，\n",
    "你会一遍又一遍地看到几乎相同的训练过程。\n",
    "在每次迭代中，我们读取一小批量训练样本，并通过我们的模型来获得一组预测。\n",
    "计算完损失后，我们开始反向传播，存储每个参数的梯度。\n",
    "最后，我们调用优化算法`sgd`来更新模型参数。\n",
    "\n",
    "概括一下，我们将执行以下循环：\n",
    "\n",
    "* 初始化参数\n",
    "* 重复以下训练，直到完成\n",
    "    * 计算梯度$\\mathbf{g} \\leftarrow \\partial_{(\\mathbf{w},b)} \\frac{1}{|\\mathcal{B}|} \\sum_{i \\in \\mathcal{B}} l(\\mathbf{x}^{(i)}, y^{(i)}, \\mathbf{w}, b)$\n",
    "    * 更新参数$(\\mathbf{w}, b) \\leftarrow (\\mathbf{w}, b) - \\eta \\mathbf{g}$\n",
    "\n",
    "在每个*迭代周期*（epoch）中，我们使用`data_iter`函数遍历整个数据集，\n",
    "并将训练数据集中所有样本都使用一次（假设样本数能够被批量大小整除）。\n",
    "这里的迭代周期个数`num_epochs`和学习率`lr`都是超参数，分别设为3和0.03。\n",
    "设置超参数很棘手，需要通过反复试验进行调整。\n",
    "我们现在忽略这些细节，以后会在 :numref:`chap_optimization`中详细介绍。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 30,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "lr = 0.003 #学习率\n",
    "num_epochs = 3 #训练次数\n",
    "net = linreg #网络，之前定义的线性网络\n",
    "loss = squared_loss #损失函数，之前定义的平方损失函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 32,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "#开始训练\n",
    "for epoch in range(num_epochs):\n",
    "    for X, y in data_iter(batch_size, features, labels):\n",
    "        l = loss(net(X, w, b), y)  # X和y的小批量损失\n",
    "        # 因为l形状是(batch_size,1)，而不是一个标量。l中的所有元素被加到一起，\n",
    "        # 并以此计算关于[w,b]的梯度\n",
    "        l.sum().backward()\n",
    "        sgd([w, b], lr, batch_size)  # 使用参数的梯度更新参数\n",
    "    with torch.no_grad():\n",
    "        train_l = loss(net(features, w, b), labels) #计算所有样本的损失函数\n",
    "        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 34
   },
   "source": [
    "因为我们使用的是自己合成的数据集，所以我们知道真正的参数是什么。\n",
    "因此，我们可以通过[**比较真实参数和通过训练学到的参数来评估训练的成功程度**]。\n",
    "事实上，真实参数和通过训练学到的参数确实非常接近。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 35,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "print(f'w的估计误差: {true_w - w.reshape(true_w.shape)}')\n",
    "print(f'b的估计误差: {true_b - b}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 36
   },
   "source": [
    "注意，我们不应该想当然地认为我们能够完美地求解参数。\n",
    "在机器学习中，我们通常不太关心恢复真正的参数，而更关心如何高度准确预测参数。\n",
    "幸运的是，即使是在复杂的优化问题上，随机梯度下降通常也能找到非常好的解。\n",
    "其中一个原因是，在深度网络中存在许多参数组合能够实现高度精确的预测。\n",
    "\n",
    "## 小结\n",
    "\n",
    "* 我们学习了深度网络是如何实现和优化的。在这一过程中只使用张量和自动微分，不需要定义层或复杂的优化器。\n",
    "* 这一节只触及到了表面知识。在下面的部分中，我们将基于刚刚介绍的概念描述其他模型，并学习如何更简洁地实现其他模型。\n",
    "\n",
    "## 练习\n",
    "\n",
    "1. 如果我们将权重初始化为零，会发生什么。算法仍然有效吗？\n",
    "1. 假设你是[乔治·西蒙·欧姆](https://en.wikipedia.org/wiki/Georg_Ohm)，试图为电压和电流的关系建立一个模型。你能使用自动微分来学习模型的参数吗?\n",
    "1. 您能基于[普朗克定律](https://en.wikipedia.org/wiki/Planck%27s_law)使用光谱能量密度来确定物体的温度吗？\n",
    "1. 如果你想计算二阶导数可能会遇到什么问题？你会如何解决这些问题？\n",
    "1. 为什么在`squared_loss`函数中需要使用`reshape`函数？\n",
    "1. 尝试使用不同的学习率，观察损失函数值下降的快慢。\n",
    "1. 如果样本个数不能被批量大小整除，`data_iter`函数的行为会有什么变化？\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 38,
    "tab": [
     "pytorch"
    ]
   },
   "source": [
    "[Discussions](https://discuss.d2l.ai/t/1778)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
